In this notebook we train the static part of the cross-domain model. The static map in the CARLA dataset consists of the day images. We train it using a Siamese network and triplet loss. The non-grey parts of the following image denote the Siamese network as a part of the bigger model.

import gc
# Force an immediate garbage-collection pass to reclaim memory held over from
# previously executed notebook cells (the number below is the cell's output:
# the count of unreachable objects collected).
gc.collect()
209
import os
import random
import re
from tqdm import tqdm
import cv2
import torch
import torchvision
import torchvision.models as models
import torch.nn as nn
import torch.utils
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import albumentations as albu
import pandas as pd
import numpy as np
import time
import h5py
from mpl_toolkits.mplot3d import Axes3D
import plotly.express as px
from plotly.offline import init_notebook_mode, iplot
from utils.utility import *
from utils.augmentations import *
from utils.losses import *
from datasets.triplet_dataset import *
from base.base_trainer import *
from networks.siamese import *
# Enable plotly's offline mode so interactive plots render inline in the notebook.
init_notebook_mode()
# Widen pandas column display so long strings are not truncated in notebook
# output. Use the full option key: partial/regex option matching (the bare
# 'max_colwidth') is deprecated in modern pandas.
pd.set_option('display.max_colwidth', 600)
%matplotlib inline
# For reproducibility: pin every RNG the pipeline touches to a fixed seed.
SEED = 42

def _seed_everything(seed):
    """Seed torch, numpy, and the stdlib RNG, and force deterministic cuDNN."""
    torch.manual_seed(seed)
    # Deterministic kernels, no autotuning -- trades speed for repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

_seed_everything(SEED)
# Directory layout under ./data:
#   h5/noon.hdf5   -- day-time ("noon") image archive
#   saved/models   -- model checkpoints
#   saved/logs     -- training logs / dumps
path_root = 'data'
path_h5, path_save = (os.path.join(path_root, sub) for sub in ('h5', 'saved'))
path_noon = os.path.join(path_h5, 'noon.hdf5')
path_models, path_dump = (os.path.join(path_save, sub) for sub in ('models', 'logs'))
# Load the day-time images from the HDF5 archive; these form the static map
# the Siamese network is trained on.
# NOTE(review): assumes get_image_arr returns a sequence of images -- confirm
# against utils.utility.
all_noon_images = get_image_arr(path_noon)
len_ds = len(all_noon_images)  # total number of day images
# Resize every image to 224x224; no other augmentation is applied to the
# static (day) map.
transforms_static = albu.Compose([
albu.Resize(224, 224),
])
# only static tfms used for the time being as day map has no tfms
train_ds_siamese = TripletDataset(all_noon_images,
tfm=transforms_static)
# NOTE(review): the validation dataset is built from an empty list, so
# validation is effectively a no-op here -- confirm this is intentional.
valid_ds_siamese = TripletDataset([], tfm=transforms_static)
# Batch size for both loaders (also reused below as the logging interval).
bs = 64
cuda = torch.cuda.is_available()
# Use a loader worker and pinned host memory only when a GPU is present;
# on CPU the DataLoader defaults are fine.
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}
# Shuffle training triplets every epoch; keep validation order fixed.
train_dl_siamese = DataLoader(train_ds_siamese, shuffle=True, batch_size=bs, **kwargs)
valid_dl_siamese = DataLoader(valid_ds_siamese, shuffle=False, batch_size=bs, **kwargs)
# Visualise example (anchor, positive, negative) triplets from the training
# set. NOTE(review): presumably `rows` controls how many triplet rows are
# drawn -- confirm in the plotting utility.
plot_triplet_ex(train_ds_siamese,rows=2)
# --- Triplet network, loss, and optimisation schedule ---
margin = 3.0        # triplet-loss margin
lr = 1e-3           # initial Adam learning rate
n_epochs = 10
log_interval = bs   # how often (in samples) training progress is logged

embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)
if cuda:
    model.cuda()

loss_fn = TripletLoss(margin)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Halve the learning rate every 8 epochs.
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.5, last_epoch=-1)
# Optional: uncomment to resume from a previously saved checkpoint instead of
# training the extractor from scratch.
# path_model = os.path.join(path_models, 'static_extractor')
# checkpoint = torch.load(path_model, map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint['model_state_dict'])
<All keys matched successfully>
# Embed all day images with the current model and visualise the embedding
# space via PCA (baseline before training).
# NOTE(review): 512 is presumably the embedding batch size -- confirm against
# generate_static_embeddings.
embeddings = generate_static_embeddings(model, all_noon_images, 512)
plot_pca(embeddings)
100%|██████████| 1355/1355 [01:20<00:00, 16.79it/s]
<Figure size 720x720 with 0 Axes>
If the interactive plot is not viewable please check train_static.html in the html folder
# Train the triplet network, then re-embed the day images and re-plot the PCA
# projection to inspect how the embedding space changed.
fit(train_dl_siamese, valid_dl_siamese, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)
embeddings = generate_static_embeddings(model, all_noon_images, 512)
plot_pca(embeddings)
100%|██████████| 1355/1355 [01:05<00:00, 20.74it/s]
<Figure size 720x720 with 0 Axes>
If the interactive plot is not viewable please check train_static.html in the html folder
It can be observed that the embeddings are now better clustered in a continuous manner.
# Persist the trained static extractor together with the optimizer state so
# training can be resumed later.
path_model = os.path.join(path_models, 'static_extractor')
# torch.save raises if the target directory is missing on a fresh run.
os.makedirs(path_models, exist_ok=True)
torch.save({
    'epoch': n_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, path_model)